% Reset the session: clear workspace variables, close figures, clear the
% command window.
clear all
close all
clc

rng(1)                                          % Seed the random number generator so the simulated draws are reproducible;
warning('off','all')                            % Turn off all MATLAB warnings;

%
%
% Sorting with Teams;
% 
% First coded: Job Boerma and Aleh Tsyvinski and Alexander Zimin, June 2023;
% Last edited: JWB, 11-12-23;
%
% Outline of this script:
%   1. Read the empirical wage series and choose the parameter sheet.
%   2. Set simulation sizes and read the Beta-distribution parameters.
%   3. Simulate firm and worker types, rearrange them and compute wages.
%   4. Print coworker wage averages and the between/within decomposition.
%   5. Plot data against model.
%   6. Optionally export the plotted series.
%   7. Local functions used above.
%
%

%% 1. Data

% Input spreadsheets, given relative to the current working directory.
data_filename       = 'data/firm_inequality.xlsx';
params_filename     = 'data/parameters.xlsx';

% Numeric contents of the 'stata' sheet; columns 6 and 1 are taken as the
% 2013 and 1981 wage series and transposed to row vectors.
% NOTE(review): the column-to-year mapping is inferred from the variable
% names only -- confirm against the spreadsheet.
data                = xlsread(data_filename,'stata');
wages_2013          = data(:, 6)';
wages_1981          = data(:, 1)';

% Only used in Section 6 to pick the export file
% (0 -> dataset_beta0.xlsx, 1 -> dataset_beta1.xlsx).
time                = 1;

%% 2. Parameters

nw                  = 5;                        % Number of worker columns drawn per simulated row (Section 3);
m                   = 10000;                    % Size multiplier for the simulation;
N                   = 99*m+1;                   % Number of simulated rows;

% All "1981" or "2013";
% Firm   Effect: 1981/2013/1981;
% Worker Effect: 2013/1981/2013;
% To generate the results of the paper, see the documentation in the README
% file.
workers             = "2013";
firms               = "2013";
target              = "2013";

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%;
%%%%%%%%%%%%% N.B. No need to change anything past this point %%%%%%%%%%%%%;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%;

% Read Data Sheet;
% The sheet name is "<workers>_<firms>_<target>", e.g. "2013_2013_2013".
% From here on "data" holds the parameter sheet, not the wage data.
sheet_name          = workers + '_' + firms +  '_' + target;
data                = xlsread(params_filename, sheet_name);

% Reading Beta Distributions;
% Row nw-1 of the sheet is used; columns 2-3 give the worker Beta shape
% parameters and columns 4-5 the firm Beta shape parameters.
% NOTE(review): presumably the sheet rows start at nw = 2 -- confirm.
row_data            = data(nw-1, :);
worker_a            = row_data(2);
worker_b            = row_data(3);
firm_a              = row_data(4);
firm_b              = row_data(5);

% To overwrite: manually enter parameters here:
% firm_a              = 1.6;
% firm_b              = 1.6;
% worker_a            = 1.8;
% worker_b            = 1.4;

% Empirical wage series that the model is fitted to in Section 3.
if target == "1981"
    w = wages_1981;
else
    w = wages_2013;
end

%% 3. Sorting

% Simulated types: column 1 holds N draws from Beta(firm_a, firm_b); each
% of the following nw columns holds N draws from Beta(worker_a, worker_b).
matrix = betarnd(firm_a, firm_b, N, 1);
for j = 1:nw
    matrix = cat(2, matrix, betarnd(worker_a, worker_b, N, 1));
end

% Reassign values within each column (see rearrange), then compute the
% per-column ascending sort order and the unscaled wages.
rearranged_matrix           = rearrange(matrix);
ids                         = get_sorting_ids(rearranged_matrix);
wages                       = get_wages(rearranged_matrix);

% Take the wages in column 2 (the first worker column), ordered by
% descending value of that column in rearranged_matrix, and interpolate
% them (pchip) from N evenly spaced points on [0, 1] down to 100 points.
% Then fit p0 and p1 so that log(p0 + p1 * w_pred) matches log(w).
original_points             = linspace(0, 1, N);
desired_points              = linspace(0, 1, 100);
w_pred                      = interp1(original_points, wages(ids(end:-1:1, 2), 2), desired_points, 'pchip');
params                      = fit_data(w, w_pred);
p0                          = params(1);
p1                          = params(2);

wages                       = p1*wages + p0;       % Apply the fitted affine map to every wage column;

%% 4. Statistics

disp('--------------------------------------------------');
disp('---------- Average Wage Coworkers-----------------');
disp('--------------------------------------------------');

% 100 binned averages of the row-wise mean log wage (columns 2:end), with
% the bin order reversed by fliplr; entries 25, 50, 75 and 90 are printed.
averages                    = fliplr(get_mean_log_wages(wages, ids));
fprintf( ...
    "A25\t%f\tA50\t%f\tA75\t%f\tA90\t%f\n \n", ...
    averages(25), averages(50), averages(75), averages(90) ...
);

% Reorder the rows of every wage column by the ascending sort order of
% column 1 of rearranged_matrix (the firm column).
for i = 1:nw+1
    wages(:, i)             = wages(ids(:, 1), i);
end 

disp('--------------------------------------------------');
disp('------------- Decomposition ----------------------');
disp('--------------------------------------------------');

% Shares of (between + within) variance of log wages. The two lines below
% have no trailing semicolon, so their values are echoed to the console.
[total, between, within]    = get_variances(wages);
frac_between                = between/(between+within)
frac_within                 = within/(between+within)

%% 5. Figures;

% Data
% x indexes the 100 points 0..99; y1 is the log of the empirical wage
% series and y2 the log of the fitted model wages p0 + p1 * w_pred.
x                           = linspace(0, 99, 100);
y2                          = log(p0 + p1 * w_pred);
y1                          = log(w);

% Create Plot
figure;                     % Create a new figure
hold on;                    % Keep the plot active so you can plot multiple lines
plot(x, y1, 'k-', 'LineWidth', 2); % Data: solid black line
plot(x, y2, '--', 'Color', [1, 0.5, 0], 'LineWidth', 3); % Model: dashed orange line
hold off; 
lgd = legend('Data', 'Model');
lgd.Orientation = 'Horizontal';
lgd.Box = 'off';

%% 6. Export Wage Distribution Data to Stata;

export                      = 0;                    % Set to 1 to export the dataset; 0 skips this section;

if export == 1;

% One row per point, columns: x, log data wage, log model wage.
dataset                     = [x;y1;y2]';

% The variable "time" selects the output file. Any existing file is deleted
% first and the dataset is then written to the same path.
% NOTE(review): the paths below are absolute and specific to one user's
% machine; they must be edited before enabling the export elsewhere.
if time == 0;
    delete '/Users/jobboerma/Dropbox/Aleh_Job/data/dataset_beta0.xlsx';
    filename = '/Users/jobboerma/Dropbox/Aleh_Job/data/dataset_beta0.xlsx';
    writematrix(dataset,filename);
elseif time == 1;
    delete '/Users/jobboerma/Dropbox/Aleh_Job/data/dataset_beta1.xlsx';
    filename = '/Users/jobboerma/Dropbox/Aleh_Job/data/dataset_beta1.xlsx';    
    writematrix(dataset,filename);
end

end;

%% 7. Modules

function params = fit_data(omega, omega_pred)
    % FIT_DATA  Fit an affine wage map on a log scale.
    %
    %   params = fit_data(omega, omega_pred) returns params = [p0, p1] such
    %   that log(p0 + p1 * omega_pred) is the least-squares fit to
    %   log(omega), as found by lsqcurvefit.
    %
    %   The optimisation runs over theta = log([p0, p1]), so both returned
    %   coefficients are strictly positive.

    % Model on the log scale, parameterised by the log-coefficients.
    log_affine = @(theta, t) log(exp(theta(1)) + exp(theta(2)) * t);

    % Starting point for the solver (in log-coefficient space).
    theta_start = [1.0, 1.0];

    % Unconstrained fit: empty lower and upper bounds.
    theta_hat = lsqcurvefit(log_affine, theta_start, omega_pred, log(omega), [], []);

    % Map the log-coefficients back to [p0, p1].
    params = exp(theta_hat);
end

function rearranged_matrix = rearrange(matrix, num_steps)
    % REARRANGE  Iteratively re-sort each column against the product of the others.
    %
    %   rearranged_matrix = rearrange(matrix)
    %   rearranged_matrix = rearrange(matrix, num_steps)
    %
    %   For each column j, the row-wise product of all other columns is
    %   computed (as the full row product divided by column j). The values
    %   of column j are then reassigned so that the row with the largest
    %   product receives the smallest value of the column, the row with the
    %   second largest product the second smallest value, and so on. Each
    %   column keeps the same set of values; only their rows change.
    %
    %   One sweep visits the columns in the order 2, 3, ..., dim, 1. At most
    %   num_steps sweeps (default 100) are performed.
    %
    %   A sweep that leaves the matrix unchanged means a fixed point has been
    %   reached: every later sweep would start from the same state and
    %   reproduce it, so the loop stops early. The result is identical to
    %   running all num_steps sweeps.
    arguments
        matrix
        num_steps = 100
    end
    [~, dim] = size(matrix);
    rearranged_matrix = matrix(:, :);

    % Explicit visiting order (2, ..., dim, 1) instead of overwriting the
    % loop index inside the loop body. Empty when dim == 0.
    column_order = mod(1:dim, dim) + 1;

    for step = 1:num_steps
        previous = rearranged_matrix;
        for j = column_order
            marginal_product = prod(rearranged_matrix, 2) ./ rearranged_matrix(:, j);
            [~, idx] = sort(marginal_product, 'descend');
            rearranged_matrix(idx, j) = sort(rearranged_matrix(:, j));
        end
        % Fixed point reached: remaining sweeps would be no-ops.
        if isequal(rearranged_matrix, previous)
            break;
        end
    end
end

function ids = get_sorting_ids(rearranged_matrix)
    % GET_SORTING_IDS  Ascending sort permutation of every column.
    %
    %   ids = get_sorting_ids(rearranged_matrix) returns a matrix of the same
    %   size as the input in which ids(:, j) lists the row indices of column
    %   j in ascending order of value, i.e.
    %   rearranged_matrix(ids(:, j), j) is sorted ascending.

    % Sorting along dimension 1 handles all columns in a single call.
    [~, ids] = sort(rearranged_matrix, 1);
end


function wages = get_wages(rearranged_matrix)
    % GET_WAGES  Integrate marginal products column by column.
    %
    %   wages = get_wages(rearranged_matrix) returns a matrix of the same
    %   size as the input. For each column j, rows are ordered by ascending
    %   value of that column. The marginal product (row product divided by
    %   the row's own entry in column j) is integrated over the column
    %   values with the trapezoid rule, and the running integral is stored
    %   with a negative sign. Each column is then shifted so that its
    %   minimum is 0.

    [num_rows, num_cols] = size(rearranged_matrix);
    wages = zeros(num_rows, num_cols);

    % The row product does not depend on the column being processed.
    row_product = prod(rearranged_matrix, 2);

    for col = 1:num_cols
        % Sorted values of this column and the rows they come from.
        [levels, order] = sort(rearranged_matrix(:, col));

        marginal_product = row_product ./ rearranged_matrix(:, col);
        heights = marginal_product(order);

        % Trapezoid areas between consecutive sorted values.
        widths = levels(2:end) - levels(1:end-1);
        areas = widths .* (heights(2:end) + heights(1:end-1)) / 2;

        % The row with the smallest value keeps 0; every other row gets
        % minus the cumulative area up to its value.
        wages(order(2:end), col) = -cumsum(areas);

        % make minimum equal to 0
        wages(:, col) = wages(:, col) - min(wages(:, col));
    end

end


function [total, between, within] = get_variances(wages)
    % GET_VARIANCES  Between/within-row decomposition of log-wage variance.
    %
    %   [total, between, within] = get_variances(wages) uses columns 2:end
    %   of wages (column 1 is skipped) and returns population variances
    %   (normalised by the number of observations) of their logs:
    %
    %     within  - mean over rows of the variance across columns in a row
    %     between - variance across rows of the row means
    %     total   - variance of all log wages pooled together
    %
    %   With population normalisation and the same number of columns in
    %   every row, total equals between + within up to rounding.
    workers_log_wages = log(wages(:, 2:end));
    within = mean(var(workers_log_wages, 1, 2));
    between = var(mean(workers_log_wages, 2), 1);
    % Pool every worker column. Using only the first column would make
    % total inconsistent with the two components above.
    total = var(workers_log_wages(:), 1);
end

function averages = get_mean_log_wages(wages, ids)
    % GET_MEAN_LOG_WAGES  Binned averages of the row-wise mean log wage.
    %
    %   averages = get_mean_log_wages(wages, ids) computes, for every row,
    %   the mean of log(wages) over columns 2:end (column 1 is skipped),
    %   orders the rows by ids(:, 2), splits them into 100 consecutive bins
    %   and returns the 1-by-100 vector of bin means.
    %
    %   Bin boundaries are round((0:100) * n / 100), where n is the number
    %   of rows. When n is a multiple of 100 every bin holds n/100 rows.
    %   Otherwise bin sizes differ by at most one row; a fractional step
    %   would give non-integer indices, which MATLAB rejects. If n < 100
    %   some bins are empty and their average is NaN.
    num_bins = 100;
    n = size(wages, 1);

    mean_log_wages = mean(log(wages(:, 2:end)), 2);
    mean_log_wages = mean_log_wages(ids(:, 2));

    % Integer bin edges: bin i covers rows edges(i)+1 .. edges(i+1).
    edges = round((0:num_bins) * n / num_bins);

    averages = zeros(1, num_bins);
    for i = 1:num_bins
        averages(i) = mean(mean_log_wages(edges(i) + 1:edges(i + 1)));
    end
end

function beta_params = read_beta_params(filename)
    % READ_BETA_PARAMS  Load Beta-distribution parameters into one table.
    %
    %   beta_params = read_beta_params(filename) reads the sheets named
    %   "1981" and "2013" from the spreadsheet filename. Every sheet row
    %   yields two table rows, the "worker" row first and the "firm" row
    %   second:
    %
    %     NumWorkers         sheet column 1
    %     Alpha, Beta        sheet columns 2-3 (worker) or 4-5 (firm)
    %     P0, Scale          sheet columns 6-7 (same for both rows)
    %
    %   The table has variables Year, NumWorkers, DistributionType, Alpha,
    %   Beta, P0 and Scale.

    column_names = ["Year", "NumWorkers", "DistributionType", "Alpha", "Beta", "P0", "Scale"];
    column_types = ["uint32", "uint32", "string", "double", "double", "double", "double"];

    % Start from an empty, typed table and append to it row by row.
    beta_params = table( ...
        'Size', [0, numel(column_names)], ...
        'VariableNames', column_names, ...
        'VariableTypes', column_types ...
    );

    for sheet_year = [1981 2013]
        sheet = xlsread(filename, string(sheet_year));
        for i = 1:size(sheet, 1)
            % Worker entry for this sheet row.
            worker_row = struct;
            worker_row.Year = sheet_year;
            worker_row.NumWorkers = sheet(i, 1);
            worker_row.DistributionType = "worker";
            worker_row.Alpha = sheet(i, 2);
            worker_row.Beta = sheet(i, 3);
            worker_row.P0 = sheet(i, 6);
            worker_row.Scale = sheet(i, 7);

            % Firm entry: same year, team size, P0 and Scale, but its own
            % shape parameters.
            firm_row = worker_row;
            firm_row.DistributionType = "firm";
            firm_row.Alpha = sheet(i, 4);
            firm_row.Beta = sheet(i, 5);

            beta_params = [beta_params; struct2table(worker_row); struct2table(firm_row)];
        end
    end
end

function row = get_row_id(params, year, num_workers, dist_type)
    % GET_ROW_ID  Index of the first table row matching the given keys.
    %
    %   row = get_row_id(params, year, num_workers, dist_type) returns the
    %   index of the first row of the table params whose Year, NumWorkers
    %   and DistributionType equal year, num_workers and dist_type.
    %
    %   Raises the error 'get_row_id:noMatch' when no row matches, instead
    %   of failing with an index-out-of-range error on an empty result.
    matches = params.Year == year ...
        & params.NumWorkers == num_workers ...
        & params.DistributionType == dist_type;

    % Only the first match is needed.
    row = find(matches, 1);

    if isempty(row)
        error('get_row_id:noMatch', ...
            'No parameter row for year %d, %d workers and distribution "%s".', ...
            year, num_workers, dist_type);
    end
end

function I = get_beta_sample(params, year, num_workers, dist_type, num_points)
    % GET_BETA_SAMPLE  Deterministic grid of Beta quantiles.
    %
    %   I = get_beta_sample(params, year, num_workers, dist_type, num_points)
    %   looks up the row of params selected by get_row_id for the given
    %   year, num_workers and dist_type. It returns the inverse Beta CDF,
    %   with that row's Alpha and Beta, evaluated at num_points evenly
    %   spaced probabilities running from 1/(2*num_points) to
    %   1 - 1/(2*num_points).

    % Evenly spaced probabilities that stay strictly inside (0, 1).
    half_cell = 1 / (2*num_points);
    probabilities = linspace(half_cell, 1 - half_cell, num_points);

    % Shape parameters of the requested distribution.
    k = get_row_id(params, year, num_workers, dist_type);
    shape_a = params.Alpha(k);
    shape_b = params.Beta(k);

    I = betainv(probabilities, shape_a, shape_b);
end

